function [r_history,s_history,beta,obj,rho] = MIDA(xlInst,xlIdx,report,yl,rho,lambda1,lambda2,c,maxIter,maxCount,beta0)
%   Solve L1-regularized multi-instance domain adaptation via ADMM and CCP
%
% [r_history,s_history,beta,obj,rho] = MIDA(xlInst,xlIdx,report,yl,rho,lambda1,lambda2,c,maxIter,maxCount,beta0)
% input:
% xlInst: n*k matrix, a collection of tweet vectors from labeled users.
% Each row represents a tweet and each column represents the count of a keyword.
% xlIdx: n*1 vector, an index vector which maps tweets to users
% (user indices are expected to run over 1..length(yl)).
% report: r*k matrix, a collection of formal report vectors.
% Each row represents a formal report and each column represents the count of a keyword.
% yl: column vector with one label per user; users with yl==1 are treated as
% positive (labels are presumably 0/1 -- confirm against callers).
% rho: the augmented Lagrangian parameter.
% lambda1: weight of the L1 penalty on beta(2:end).
% lambda2: weight of the report-tweet / tweet-tweet distance terms.
% c: number of folds used to partition the formal reports and the selected
% tweets of positive users.
% maxIter: maximal number of ADMM iterations.
% maxCount: stop after this many consecutive iterations in which neither the
% primal residual r nor the dual residual s reaches a new minimum.
% beta0 (optional): (k+1)*1 vector, initialization of beta (intercept first).
% where:
% n = number of tweets.
% k = number of keywords.
% r = number of formal reports.
% output:
% r_history: history of the primal residual norm r.
% s_history: history of the dual residual norm s.
% beta: (k+1)*1 parameter vector; beta(1) is the intercept.
% obj: the objective value at the final beta.
% rho: the augmented Lagrangian parameter (returned unchanged while the
% adaptive rho-update inside the loop is commented out).
k = size(xlInst,2);
user_num = length(yl);
pos_num = sum(yl==1);
TOLERANCE1 = 10e-3;
TOLERANCE2 = 10e-3;
% prepend an intercept column to reports and tweets
report = [ones(size(report,1),1),report];
inst = [ones(size(xlInst,1),1),xlInst];
idx = xlIdx;
% initialization
if exist('beta0','var')
    beta = beta0;
else
    beta = zeros(k+1,1);
    beta(1) = -1;
end
tweet_num = size(inst,1);
report_num = size(report,1);
h = zeros(tweet_num,1);     % scaled dual variable
% r, s and mu are only read by the (commented-out) adaptive rho-update
r = 0;
s = 0;
mu = 4;
r_history = [];
s_history = [];
rMin = 9999;
sMin = 9999;
count = 0;
% partition formal report sets and tweet sets of positive users
% (fixed seed so that the folds are reproducible)
rng(100);
report_indices = crossvalind('Kfold', report_num, c);
rng(100);
pos_indices = crossvalind('Kfold', pos_num, c);
% rows of inst that belong to each user; idx never changes, so compute once
user_rows = cell(1,user_num);
for i = 1:user_num
    user_rows{i} = find(idx==i);
end
I = zeros(1,user_num);
for iter = 1:maxIter
    iter    % print the iteration counter
        %rho-update
%              if(r>mu*s)
%                  rho =rho*2;
%              else if (mu*r<s)
%                      rho = rho/2;
%                  end
%              end
    %rho =rho +2/MAX_ITER;
    % for every user pick the instance with the largest score inst*beta;
    % the picked instances of positive users are collected in pos_inst
    temp = inst*beta;
    pos_inst = zeros(pos_num,k+1);
    p = 0;
    for i = 1:user_num
        rows = user_rows{i};
        t = temp(rows);
        I(i) = find(t==max(t),1,'first');
        if yl(i)==1
            p = p+1;
            pos_inst(p,:) = inst(rows(I(i)),:);
        end
    end
    % squared-difference terms; rebuilt every iteration because the
    % selected instances (I) change with beta
    [report_tweet,tweet_tweet] = build_distance_terms(report,pos_inst,report_indices,pos_indices,c);
    % S-update
    S = update_S(yl,temp,h,rho,I,idx);
    beta_old = beta;
    % beta-update
    beta = update_beta(inst,report_tweet,tweet_tweet,h,rho,lambda1,lambda2,beta,S,report_num,pos_num);
    % compute primal and dual residuals
    r = S-inst*beta;
    s = rho*inst*(beta_old-beta);
    h = h + r;
    r = norm(r);
    r_history(iter) = r;
    s = norm(s);
    s_history(iter) = s;
    % reset the stall counter whenever r (or else s) reaches a new minimum
    if r < rMin
        rMin = r;
        count = 0;
    elseif s < sMin
        sMin = s;
        count = 0;
    else
        count = count+1;
    end
    % termination checks
    if (r<=TOLERANCE1 && s<=TOLERANCE2) || count==maxCount
        break;
    end
end
% compute objective value
obj = objective(S,yl,lambda1,lambda2,beta,I,report_tweet,tweet_tweet,idx,report_num,pos_num);
end

function [report_tweet,tweet_tweet] = build_distance_terms(report,pos_inst,report_indices,pos_indices,c)
% Stack, fold by fold, element-wise squared differences (one row per pair,
% one column per feature):
%   report_tweet: every selected positive tweet vs every report in its fold;
%   tweet_tweet:  every unordered pair of selected positive tweets in a fold.
d = size(report,2);
report_tweet = zeros(0,d);
tweet_tweet = zeros(0,d);
for j = 1:c
    report_data = report(report_indices==j,:);
    tweet_data = pos_inst(pos_indices==j,:);
    report_fold_num = size(report_data,1);
    tweet_fold_num = size(tweet_data,1);
    temp1 = zeros(report_fold_num*tweet_fold_num,d);
    temp2 = zeros(tweet_fold_num*(tweet_fold_num-1)/2,d);
    row2 = 0;   % number of temp2 rows filled so far
    for i = 1:tweet_fold_num
        % tweet i against every report of the fold (the last tweet included)
        temp1(((i-1)*report_fold_num+1):(i*report_fold_num),:) = ...
            (repmat(tweet_data(i,:),report_fold_num,1)-report_data).^2;
        % tweet i against every later tweet of the fold; the block length
        % shrinks with i, so the offset is the running total row2
        remain_num = tweet_fold_num-i;
        if remain_num > 0
            temp2((row2+1):(row2+remain_num),:) = ...
                (repmat(tweet_data(i,:),remain_num,1)-tweet_data(i+1:end,:)).^2;
            row2 = row2+remain_num;
        end
    end
    report_tweet = [report_tweet;temp1];
    tweet_tweet = [tweet_tweet;temp2];
end
end
function obj =objective(S,yl,lambda1,lambda2,beta,I,report_tweet,tweet_tweet,idx,report_num,pos_num)
% Evaluate the objective at (S, beta):
%   sum over users of log(1+exp(t_i)) - t*yl      (logistic loss)
%   + lambda1*||beta(2:end)||_1
%   + 2*lambda2*sum(report_tweet*beta.^2)/report_num/pos_num
%   - 2*lambda2*sum(tweet_tweet*beta.^2)/pos_num^2
% where t_i is the I(i)-th entry of S among the entries with idx==i.
% yl is used as a column vector (t is a row, t*yl an inner product).
n = length(yl);
t = zeros(1,n);
for i = 1:n
    tt = S(idx==i);
    t(i) = tt(I(i));
end
% log(1+exp(t)) in a form that does not overflow to Inf for large t
loss = sum(max(t,0) + log1p(exp(-abs(t)))) - t*yl;
obj = loss + lambda1*norm(beta(2:end),1) ...
    + 2*lambda2*(sum(report_tweet*beta.^2)/report_num/pos_num) ...
    - 2*lambda2*sum(tweet_tweet*beta.^2)/pos_num^2;
end
function S = update_S(yl,temp,h,rho,I,idx)
% Solve the S-update of the ADMM iteration.
% Every entry starts from the closed form temp-h. For the selected entry of
% user i (the I(i)-th entry among those with idx==i) the values t solve
%   min_t  sum(log(1+exp(t))) - t*yl + (rho/2)*||t - temp1 + h1||^2
% via FISTA, handling the quadratic term by its exact proximal step.
% yl is used as a column vector; temp is the current score vector and h the
% scaled dual variable (both indexed like idx).
n = length(yl);
S = temp-h;
% selected score and dual value of every user
temp1 = zeros(1,n);
h1 = zeros(1,n);
for i = 1:n
    tt = temp(idx==i);
    temp1(i) = tt(I(i));
    tt = h(idx==i);
    h1(i) = tt(I(i));
end
t = zeros(1,n);
yt = 10e10;
MAX_ITER = 500;
TOLERANCE = 10e-5;
% overflow-safe softplus log(1+exp(t)) and logistic sigmoid; the naive forms
% give Inf (objective) and Inf/Inf = NaN (gradient) once t exceeds ~709
softplus = @(v) max(v,0) + log1p(exp(-abs(v)));
sigmoid = @(v) 1./(1+exp(-v));
y = @(t) (sum(softplus(t))-t*yl + (rho/2)*norm(t - temp1 + h1,2)^2);
lambda = 1;
zeta = t;
eta = 0.5;      % gradient step size
% FISTA
for iter = 1:MAX_ITER
    yt_old = yt;
    yt = y(t);
    if abs(yt-yt_old) < TOLERANCE
        break;
    end
    lambda_old = lambda;
    lambda = (1+sqrt(1+4*lambda^2))/2;
    gamma = (1-lambda_old)/lambda;
    gradient = sigmoid(t)-yl';
    zeta_old = zeta;
    % closed-form minimizer of (rho/2)||z-(temp1-h1)||^2 + ||z-(t-eta*gradient)||^2/(2*eta)
    zeta = (rho*(temp1-h1)+(t-eta*gradient)/eta)/(rho+1/eta);
    t = (1-gamma)*zeta+gamma*zeta_old;
end
% write the optimized selected entries back into S
for i = 1:n
    tt = S(idx==i);
    tt(I(i)) = t(i);
    S(idx==i) = tt;
end
end

function beta = update_beta(inst,report_tweet,tweet_tweet,h,rho,lambda1,lambda2,beta0,S,report_num,pos_num)
% Solve the beta-update of the ADMM iteration via CCP.
% The subproblem is
%   lambda1*||beta(2:end)||_1 + 2*lambda2*w_report'*beta.^2
%   + (rho/2)*||S - inst*beta + h||^2 - 2*lambda2*w_tweet'*beta.^2
% with w_report = column sums of report_tweet/(report_num*pos_num) and
% w_tweet = column sums of tweet_tweet/pos_num^2. Each CCP step linearizes the
% concave last term at beta_q and solves the resulting convex surrogate with
% FISTA, using soft-thresholding for the L1 term; beta(1) is not penalized.
beta = beta0;
f = 10e10;
MAX_ITER1 = 50;     % CCP iterations
MAX_ITER2 = 50;     % FISTA iterations per CCP surrogate
TOLERANCE = 10e-5;
eta = 10e-8;        % fixed gradient step size
% per-feature weights of the distance terms (loop invariant)
w_report = sum(report_tweet,1)'/report_num/pos_num;
w_tweet = sum(tweet_tweet,1)'/pos_num^2;
for iter1 = 1:MAX_ITER1
    beta_q = beta;
    % convex surrogate, equal to the objective at beta_q; the quadratic
    % penalty is squared to match the gradient below and the S-update
    y = @(beta) lambda1*norm(beta(2:end),1) + 2*lambda2*(w_report'*beta.^2) ...
        + rho/2*norm(S-inst*beta+h,2)^2 ...
        - 2*lambda2*(w_tweet'*beta_q.^2) - 4*lambda2*(w_tweet.*beta_q)'*(beta-beta_q);
    f_old = f;
    f = y(beta);
    if abs(f-f_old) < TOLERANCE
        break;
    end
    % FISTA, restarted for each surrogate so that neither the momentum nor
    % the stopping test mixes values from a previous surrogate
    lambda = 1;
    ybeta = 10e10;
    zeta = beta;
    for iter2 = 1:MAX_ITER2
        ybeta_old = ybeta;
        ybeta = y(beta);
        if abs(ybeta-ybeta_old) < TOLERANCE
            break;
        end
        lambda_old = lambda;
        lambda = (1+sqrt(1+4*lambda^2))/2;
        gamma = (1-lambda_old)/lambda;
        gradient = 4*lambda2*w_report.*beta - rho*inst'*(S-inst*beta+h) - 4*lambda2*w_tweet.*beta_q;
        zeta_old = zeta;
        zeta(2:end) = soft(beta(2:end)-eta*gradient(2:end),eta*lambda1);
        zeta(1) = beta(1)-eta*gradient(1);
        beta = (1-gamma)*zeta+gamma*zeta_old;
    end
end
end
function zeta=soft(x,alpha)
% Element-wise soft-thresholding: each entry of x is moved toward zero by
% alpha, and entries with |x| <= alpha become zero (for nonnegative alpha).
magnitude = max(abs(x)-alpha,0);
zeta = sign(x).*magnitude;
end